function [dec, U_f, U_m, U_f0, U_m0, U_ftot, U_mtot] = solution_backwards(n_g, T, lt, dt, H_max, disc, det_earn_f, det_earn_m, stoch_earn_f, stoch_earn_m, m_eta_f, m_eta_m, n_eta_f, n_eta_m, psi, home_prod, th1, match_qual, m_th, n_th, gamma, L_max, lambda, sigma_f, sigma_m)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% The impact of divorce laws on the equilibrium in the marriage market.
% Reynoso
% April 2024
%
% This function takes the parameters of the model and solves the
% model by backwards induction under Mutual Consent Divorce.
%
% Data: PSID 1968-1992
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Inputs (only the roles visible in this function are described; the
% helpers wageproc_fn, thetaproc, autarky, coopD, stayM and w are defined
% elsewhere in the project):
%   n_g          number of couple groups. Groups are laid out as
%                g = (f-1)*k + m, with f the female type, m the male type
%                and k = sqrt(n_g); n_g must therefore be a perfect square.
%   T, lt        last and first periods of the backward induction.
%   dt           forwarded to autarky, coopD, stayM and w.
%   H_max        size of the H state. The second stay option reads H+1 in
%                the next period, so H_max must be at least T.
%   disc         discount factor applied to next-period expected values.
%   det_earn_f/m, stoch_earn_f/m, m_eta_f/m, n_eta_f/m
%                earnings inputs; row g of each matrix is used for group g,
%                n_eta_f/m are the wage-grid sizes.
%   psi, home_prod, th1, lambda
%                group-specific parameters forwarded to stayM. lambda(g)
%                is also the weight on the wife's value when choosing
%                between the two stay options.
%   match_qual, m_th, n_th
%                match-quality process inputs (row g of match_qual is
%                used); n_th is the grid size.
%   gamma        forwarded to autarky and coopD.
%   L_max        number of points on the divorce Pareto-weight grid.
%   sigma_f/m    added to the singles' values of group g.
%
% Outputs:
%   dec          n_g x H_max x n_eta_f x n_eta_m x n_th x T decisions of
%                married couples: 1 = first stay option (H kept next
%                period), 2 = second stay option (H+1 next period),
%                3 = divorce. Divorce is not evaluated at t = lt; entries
%                for t < lt are 0.
%   U_f, U_m     wife's / husband's married value at t = lt and H = 1,
%                averaged over wage and match-quality states.
%   U_f0, U_m0   average single value at t = lt plus sigma_f(g)/sigma_m(g).
%   U_ftot       for group g: sum of exp(U_f) over all groups sharing g's
%                female type, plus exp(U_f0) of that female type.
%   U_mtot       male counterpart of U_ftot.

%--------------------------- Preliminaries ------------------------------%

% The U_ftot/U_mtot aggregation at the end assumes a square female-by-male
% group layout; check it before the expensive solution, not after it.
n_types = round(sqrt(n_g));
if n_types^2 ~= n_g
    error('solution_backwards:nGroups', ...
        'n_g must be a perfect square (female types x male types); got n_g = %d.', n_g);
end

% Pareto-weight grid over which divorce transfers are searched; the end
% points are moved strictly inside (0,1).
% NOTE(review): the two end nudges differ in size (1e-6 vs 1e-5) -- confirm intended.
lambda_grid = linspace(0, 1, L_max);
lambda_grid(1) = 0.000001;
lambda_grid(L_max) = 0.99999;

dec = zeros(n_g, H_max, n_eta_f, n_eta_m, n_th, T);
U_f = zeros(n_g, 1);
U_m = zeros(n_g, 1);
U_f0 = zeros(n_g, 1);
U_m0 = zeros(n_g, 1);
U_ftot = zeros(n_g, 1);
U_mtot = zeros(n_g, 1);

%--------------------------- Value functions -----------------------------%

parfor g = 1:n_g

    %--- Shock processes for this group
    [lw_grid_m, P_eta_m, ~, ~] = wageproc_fn(stoch_earn_m(g,:), m_eta_m, n_eta_m);
    [lw_grid_f, P_eta_f, ~, ~] = wageproc_fn(stoch_earn_f(g,:), m_eta_f, n_eta_f);

    % Joint wage transition as an outer product of the two chains:
    % Pwage(i,j,i2,j2) = P_eta_f(i,i2) * P_eta_m(j,j2).
    Pwage = zeros(n_eta_f, n_eta_m, n_eta_f, n_eta_m);
    for i = 1:n_eta_f
        for j = 1:n_eta_m
            Pwage(i,j,:,:) = P_eta_f(i,:)' * P_eta_m(j,:);
        end
    end

    [thtau_grid, P_thtau, ~, ~] = thetaproc(match_qual(g,:), m_th, n_th);

    % Joint (wages, match quality) transition:
    % P_wg_th(i,j,k,i2,j2,k2) = Pwage(i,j,i2,j2) * P_thtau(k,k2).
    P_wg_th = zeros(n_eta_f, n_eta_m, n_th, n_eta_f, n_eta_m, n_th);
    for i = 1:n_eta_f
        for j = 1:n_eta_m
            for k = 1:n_th
                for kp = 1:n_th
                    P_wg_th(i,j,k,:,:,kp) = Pwage(i,j,:,:) * P_thtau(k,kp);
                end
            end
        end
    end

    % lambda_grid point closest to this group's own weight lambda(g); the
    % divorce search starts here.
    [~, x] = min(abs(lambda_grid - lambda(g)));

    %--- Autarky: flow values from autarky() plus discounted expectation
    E_Af = zeros(H_max, n_eta_f, n_eta_m, T);
    E_Am = zeros(H_max, n_eta_f, n_eta_m, T);
    val_Af = zeros(H_max, n_eta_f, n_eta_m, T);
    val_Am = zeros(H_max, n_eta_f, n_eta_m, T);

    for H = 1:H_max
        for wf = 1:n_eta_f
            for wm = 1:n_eta_m
                [val_Af(H,wf,wm,T), val_Am(H,wf,wm,T), ~] = autarky(H, lw_grid_f(wf), ...
                    lw_grid_m(wm), T, det_earn_f(g,:), det_earn_m(g,:), gamma, dt);
            end
        end
    end

    for t = T-1:-1:lt
        for H = 1:t
            for wf = 1:n_eta_f
                for wm = 1:n_eta_m
                    [flow_f, flow_m, ~] = autarky(H, lw_grid_f(wf), lw_grid_m(wm), t, ...
                        det_earn_f(g,:), det_earn_m(g,:), gamma, dt);
                    A = reshape(Pwage(wf,wm,:,:), n_eta_f, n_eta_m);
                    E_Af(H,wf,wm,t) = sum(sum(A .* reshape(val_Af(H,:,:,t+1), n_eta_f, n_eta_m)));
                    E_Am(H,wf,wm,t) = sum(sum(A .* reshape(val_Am(H,:,:,t+1), n_eta_f, n_eta_m)));
                    val_Af(H,wf,wm,t) = flow_f + disc*E_Af(H,wf,wm,t);
                    val_Am(H,wf,wm,t) = flow_m + disc*E_Am(H,wf,wm,t);
                end
            end
        end
    end

    %--- Divorce: coopD flow values at every lambda_grid point, continued
    % with the autarky expectation E_Af/E_Am.
    val_Df = zeros(L_max, H_max, n_eta_f, n_eta_m, T);
    val_Dm = zeros(L_max, H_max, n_eta_f, n_eta_m, T);

    for p = 1:L_max
        for H = 1:H_max
            for wf = 1:n_eta_f
                for wm = 1:n_eta_m
                    [val_Df(p,H,wf,wm,T), val_Dm(p,H,wf,wm,T), ~] = coopD(H, lw_grid_f(wf), ...
                        lw_grid_m(wm), T, det_earn_f(g,:), det_earn_m(g,:), gamma, ...
                        lambda_grid(p), dt);
                end
            end
        end
    end

    for t = T-1:-1:lt
        for p = 1:L_max
            for H = 1:t
                for wf = 1:n_eta_f
                    for wm = 1:n_eta_m
                        [flow_f, flow_m, ~] = coopD(H, lw_grid_f(wf), lw_grid_m(wm), t, ...
                            det_earn_f(g,:), det_earn_m(g,:), gamma, lambda_grid(p), dt);
                        val_Df(p,H,wf,wm,t) = flow_f + disc*E_Af(H,wf,wm,t);
                        val_Dm(p,H,wf,wm,t) = flow_m + disc*E_Am(H,wf,wm,t);
                    end
                end
            end
        end
    end

    %--- Marriage
    % val_matf/val_matm(...,1|2) hold the wife's/husband's value of the two
    % stay options returned by stayM; option 1 continues with H, option 2
    % with H+1.
    val_matf = zeros(H_max, n_eta_f, n_eta_m, n_th, T, 2);
    val_matm = zeros(H_max, n_eta_f, n_eta_m, n_th, T, 2);
    dec_M = zeros(H_max, n_eta_f, n_eta_m, n_th, T);
    val_f = zeros(H_max, n_eta_f, n_eta_m, n_th, T);
    val_m = zeros(H_max, n_eta_f, n_eta_m, n_th, T);

    % Terminal period: flow values only.
    for H = 1:H_max
        for wf = 1:n_eta_f
            for wm = 1:n_eta_m
                for th = 1:n_th
                    [val_matf(H,wf,wm,th,T,1), val_matm(H,wf,wm,th,T,1), ...
                        val_matf(H,wf,wm,th,T,2), val_matm(H,wf,wm,th,T,2)] = stayM(H, ...
                        lw_grid_f(wf), lw_grid_m(wm), T, det_earn_f(g,:), det_earn_m(g,:), ...
                        home_prod(g,:), psi(g), th1(g), thtau_grid(th), lambda(g), dt);
                end
            end
        end
    end

    [dec_M(:,:,:,:,T), stay_f, stay_m] = choose_stay_option( ...
        val_matf(:,:,:,:,T,:), val_matm(:,:,:,:,T,:), lambda(g));
    [val_f(:,:,:,:,T), val_m(:,:,:,:,T), dec_M(:,:,:,:,T)] = mutual_consent_divorce( ...
        stay_f, stay_m, val_Df(:,:,:,:,T), val_Dm(:,:,:,:,T), dec_M(:,:,:,:,T), x);

    for t = T-1:-1:lt
        for H = 1:t
            for wf = 1:n_eta_f
                for wm = 1:n_eta_m
                    for th = 1:n_th
                        % In the first period (t == lt) stayM receives 0 in
                        % place of the match-quality grid value.
                        if t > lt
                            th_now = thtau_grid(th);
                        else
                            th_now = 0;
                        end
                        [s_f0, s_m0, s_f1, s_m1] = stayM(H, lw_grid_f(wf), lw_grid_m(wm), t, ...
                            det_earn_f(g,:), det_earn_m(g,:), home_prod(g,:), psi(g), ...
                            th1(g), th_now, lambda(g), dt);

                        Awt = reshape(P_wg_th(wf,wm,th,:,:,:), n_eta_f, n_eta_m, n_th);
                        E_f0 = sum(sum(sum(Awt .* reshape(val_f(H,:,:,:,t+1), n_eta_f, n_eta_m, n_th))));
                        E_m0 = sum(sum(sum(Awt .* reshape(val_m(H,:,:,:,t+1), n_eta_f, n_eta_m, n_th))));
                        E_f1 = sum(sum(sum(Awt .* reshape(val_f(H+1,:,:,:,t+1), n_eta_f, n_eta_m, n_th))));
                        E_m1 = sum(sum(sum(Awt .* reshape(val_m(H+1,:,:,:,t+1), n_eta_f, n_eta_m, n_th))));

                        val_matf(H,wf,wm,th,t,1) = s_f0 + disc*E_f0;
                        val_matf(H,wf,wm,th,t,2) = s_f1 + disc*E_f1;
                        val_matm(H,wf,wm,th,t,1) = s_m0 + disc*E_m0;
                        val_matm(H,wf,wm,th,t,2) = s_m1 + disc*E_m1;
                    end
                end
            end
        end

        % The choice runs over all H_max rows; rows H > t were not filled at
        % this date, are all zero, and end up with decision 1 and value 0.
        [dec_M(:,:,:,:,t), stay_f, stay_m] = choose_stay_option( ...
            val_matf(:,:,:,:,t,:), val_matm(:,:,:,:,t,:), lambda(g));

        if t > lt
            % Zero-filled rows H > t compare 0 with 0 and are left unchanged.
            [val_f(:,:,:,:,t), val_m(:,:,:,:,t), dec_M(:,:,:,:,t)] = mutual_consent_divorce( ...
                stay_f, stay_m, val_Df(:,:,:,:,t), val_Dm(:,:,:,:,t), dec_M(:,:,:,:,t), x);
        else
            % No divorce decision in the first period.
            val_f(:,:,:,:,t) = stay_f;
            val_m(:,:,:,:,t) = stay_m;
        end
    end

    dec(g,:,:,:,:,:) = dec_M;

    U_f(g) = mean(mean(mean(val_f(1,:,:,:,lt))));
    U_m(g) = mean(mean(mean(val_m(1,:,:,:,lt))));

    %--- Singles: log(0.61 * w(..., H = 0, ...)) flow plus the discounted
    % expectation under the individual wage transition.
    % NOTE(review): 0.61 presumably is an equivalence-scale factor -- confirm.
    valm_s = zeros(n_eta_m, T);
    valf_s = zeros(n_eta_f, T);

    for wm = 1:n_eta_m
        valm_s(wm,T) = log(0.61*w(det_earn_m(g,:), T, 0, lw_grid_m(wm), dt));
    end
    for wf = 1:n_eta_f
        valf_s(wf,T) = log(0.61*w(det_earn_f(g,:), T, 0, lw_grid_f(wf), dt));
    end

    for t = T-1:-1:lt
        for wm = 1:n_eta_m
            valm_s(wm,t) = log(0.61*w(det_earn_m(g,:), t, 0, lw_grid_m(wm), dt)) ...
                + disc*(P_eta_m(wm,:)*valm_s(:,t+1));
        end
        for wf = 1:n_eta_f
            valf_s(wf,t) = log(0.61*w(det_earn_f(g,:), t, 0, lw_grid_f(wf), dt)) ...
                + disc*(P_eta_f(wf,:)*valf_s(:,t+1));
        end
    end

    U_f0(g) = mean(valf_s(:,lt)) + sigma_f(g);
    U_m0(g) = mean(valm_s(:,lt)) + sigma_m(g);

end

%--------------------------- Aggregation ---------------------------------%
% Groups are laid out as grp = (f-1)*n_types + m (female type f, male type
% m). Every group sharing a female type gets the same U_ftot; every group
% sharing a male type gets the same U_mtot. Terms are accumulated left to
% right so the floating-point result matches the explicit sums.
for f = 1:n_types
    grp = (f-1)*n_types + (1:n_types);      % all partners of female type f
    acc = 0;
    for gg = grp
        acc = acc + exp(U_f(gg));
    end
    U_ftot(grp) = acc + exp(U_f0(grp(1)));
end

for m = 1:n_types
    grp = m + (0:n_types-1)*n_types;        % all partners of male type m
    acc = 0;
    for gg = grp
        acc = acc + exp(U_m(gg));
    end
    U_mtot(grp) = acc + exp(U_m0(grp(1)));
end

end


function [d, vf, vm] = choose_stay_option(mat_f, mat_m, wgt)
%CHOOSE_STAY_OPTION Pick the stay option maximizing the weighted value.
%   mat_f, mat_m : wife's / husband's values; dimension 6 (size 2) indexes
%                  the two stay options, dimensions 1-4 are (H, wf, wm, th).
%   wgt          : weight on the wife's value (1 - wgt on the husband's).
%   d            : index (1 or 2) of the chosen option; ties pick 1, as max does.
%   vf, vm       : wife's / husband's value under the chosen option.
[~, d] = max(wgt*mat_f + (1-wgt)*mat_m, [], 6);
vf = mat_f(:,:,:,:,:,1);
vm = mat_m(:,:,:,:,:,1);
alt_f = mat_f(:,:,:,:,:,2);
alt_m = mat_m(:,:,:,:,:,2);
take2 = (d == 2);
vf(take2) = alt_f(take2);
vm(take2) = alt_m(take2);
end


function [vf, vm, dm] = mutual_consent_divorce(vMf, vMm, vDf, vDm, dm, x)
%MUTUAL_CONSENT_DIVORCE Apply the mutual-consent divorce rule at one date.
%   vMf, vMm : wife's / husband's value of staying married, (H, wf, wm, th).
%   vDf, vDm : wife's / husband's divorce value, (p, H, wf, wm), where p
%              indexes the Pareto-weight grid.
%   dm       : stay-option decisions, same indexing as vMf.
%   x        : grid point closest to the couple's own weight.
%
%   The couple divorces (decision 3) when both strictly prefer divorce at
%   x. If only the husband prefers divorce there, grid points above x
%   (higher weight on the wife) are scanned upward. If only the wife
%   prefers it, points below x are scanned downward. The first point at
%   which both strictly prefer divorce is used, and its divorce values
%   replace the married values. Otherwise the married values and dm are
%   returned unchanged.
vf = vMf;
vm = vMm;
n_p = size(vDf, 1);
[nH, nf, nm, nth] = size(vMf);
for H = 1:nH
    for wf = 1:nf
        for wm = 1:nm
            Df = vDf(:, H, wf, wm);
            Dm = vDm(:, H, wf, wm);
            for th = 1:nth
                mf = vMf(H, wf, wm, th);
                mm = vMm(H, wf, wm, th);
                if mf < Df(x) && mm < Dm(x)
                    p = x;
                elseif mf >= Df(x) && mm < Dm(x)
                    p = x + find(mf < Df(x+1:n_p) & mm < Dm(x+1:n_p), 1, 'first');
                elseif mf < Df(x) && mm >= Dm(x)
                    p = find(mf < Df(1:x-1) & mm < Dm(1:x-1), 1, 'last');
                else
                    p = [];
                end
                if ~isempty(p)
                    vf(H, wf, wm, th) = Df(p);
                    vm(H, wf, wm, th) = Dm(p);
                    dm(H, wf, wm, th) = 3;
                end
            end
        end
    end
end
end

